//----------------------------------------------------------------------------
//
// Copyright (C) Sartorius Stedim Data Analytics AB 2017 -
//
// Use, modification and distribution are subject to the Boost Software
// License, Version 1.0. (See http://www.boost.org/LICENSE_1_0.txt)
//
//----------------------------------------------------------------------------

// This is an example program that uses the COM interface of the SIMCA-Q DLL. To build
// and run this application, you must copy SIMCAQ.tlb to the directory that contains
// the source files.
//
// You must also register SIMCA-Q.dll with regsvr32.

#include "StdAfx.h"
#include "SQPPlusCOMSample.h"
#include "SQPPlusCOMSampleDlg.h"

/////////////////////////////////////////////////////////////////////////////
// CAboutDlg dialog used for App About

class CAboutDlg : public CDialog
{
public:
   // Dialog template resource used for the About box.
   enum { IDD = IDD_ABOUTBOX };

   // Modal "About" box shown from the system menu; it adds nothing to the CDialog defaults.
   CAboutDlg()
      : CDialog(IDD)
   {
   }
};

/////////////////////////////////////////////////////////////////////////////
// Utilities

namespace
{
   // Returns true for the model types that have a Y block (the PLS, OPLS and O2PLS
   // families, including their _DA and _Class variants). The rest of this file only
   // requests Y-related results (C, YPredPS, aligned results) for these types.
   bool IsPredictiveModel(ModelType eModelType)
   {
      static const ModelType aePredictiveTypes[] =
      {
         ePLS,   ePLS_DA,   ePLS_Class,
         eOPLS,  eOPLS_DA,  eOPLS_Class,
         eO2PLS, eO2PLS_DA, eO2PLS_Class
      };
      const size_t nTypes = sizeof(aePredictiveTypes) / sizeof(aePredictiveTypes[0]);

      for (size_t iType = 0; iType < nTypes; ++iType)
      {
         if (aePredictiveTypes[iType] == eModelType)
            return true;
      }
      return false;
   }
}


/////////////////////////////////////////////////////////////////////////////
// CSQPPlusCOMSampleDlg dialog

// Constructs the main dialog from its IDD template and loads the application icon
// (IDR_MAINFRAME). The icon is later used by OnInitDialog (SetIcon), OnPaint (drawing
// while minimized) and OnQueryDragIcon.
CSQPPlusCOMSampleDlg::CSQPPlusCOMSampleDlg(CWnd* pParent /*=NULL*/)
   : CDialog(CSQPPlusCOMSampleDlg::IDD, pParent)
   , m_hIcon(AfxGetApp()->LoadIcon(IDR_MAINFRAME))
{
}

// No explicit cleanup: OnStart closes mpFile itself before returning normally.
// NOTE(review): mpSQ / mpProject appear to be COM smart pointers (they are assigned
// from CreateInstance / OpenProject); confirm in the header that they release on destruction.
CSQPPlusCOMSampleDlg::~CSQPPlusCOMSampleDlg()
{
}

// DDX/DDV hook. No controls are bound to member variables in this dialog, so only the
// base-class exchange runs.
void CSQPPlusCOMSampleDlg::DoDataExchange(CDataExchange* pDX)
{
   CDialog::DoDataExchange(pDX);
}

// Routes window messages to the handlers below: system menu (About box), painting of the
// minimized icon, the drag cursor, and the start button (IDC_BUTTON1 -> OnStart).
BEGIN_MESSAGE_MAP(CSQPPlusCOMSampleDlg, CDialog)
   ON_WM_SYSCOMMAND()
   ON_WM_PAINT()
   ON_WM_QUERYDRAGICON()
   ON_BN_CLICKED(IDC_BUTTON1, OnStart)
END_MESSAGE_MAP()

/////////////////////////////////////////////////////////////////////////////
// CSQPPlusCOMSampleDlg message handlers

BOOL CSQPPlusCOMSampleDlg::OnInitDialog()
{
   CDialog::OnInitDialog();

   // The About command is sent as a WM_SYSCOMMAND, so its ID must have the low four bits
   // clear and lie below 0xF000, where the predefined SC_* commands start.
   ASSERT((IDM_ABOUTBOX & 0xFFF0) == IDM_ABOUTBOX);
   ASSERT(IDM_ABOUTBOX < 0xF000);

   // Append a separator and the "About..." entry to the system menu, provided the menu
   // exists and the IDS_ABOUTBOX caption is not empty.
   if (CMenu* pMenu = GetSystemMenu(FALSE))
   {
      CString strCaption;
      strCaption.LoadString(IDS_ABOUTBOX);
      if (strCaption.GetLength() > 0)
      {
         pMenu->AppendMenu(MF_SEPARATOR);
         pMenu->AppendMenu(MF_STRING, IDM_ABOUTBOX, strCaption);
      }
   }

   // MFC only sets the window icon automatically when the main window is not a dialog,
   // so install it here for both the large and the small slot.
   SetIcon(m_hIcon, TRUE);    // big icon
   SetIcon(m_hIcon, FALSE);   // small icon

   return TRUE;  // TRUE: let the system put the focus on the first control.
}

void CSQPPlusCOMSampleDlg::OnSysCommand(UINT nID, LPARAM lParam)
{
   // Windows uses the four low-order bits of a system command internally; mask them off
   // before comparing with our own About command ID.
   const bool bIsAboutCommand = (nID & 0xFFF0) == IDM_ABOUTBOX;
   if (!bIsAboutCommand)
   {
      CDialog::OnSysCommand(nID, lParam);
      return;
   }

   CAboutDlg oAboutBox;
   oAboutBox.DoModal();
}

// If you add a minimize button to your dialog, you will need the code below
//  to draw the icon.  For MFC applications using the document/view model,
//  this is automatically done for you by the framework.

void CSQPPlusCOMSampleDlg::OnPaint()
{
   if (IsIconic())
   {
      CPaintDC dc(this); // device context for painting

      SendMessage(WM_ICONERASEBKGND, (WPARAM)dc.GetSafeHdc(), 0);

      // Center icon in client rectangle
      int cxIcon = GetSystemMetrics(SM_CXICON);
      int cyIcon = GetSystemMetrics(SM_CYICON);
      CRect rect;
      GetClientRect(&rect);
      int x = (rect.Width() - cxIcon + 1) / 2;
      int y = (rect.Height() - cyIcon + 1) / 2;

      // Draw the icon
      dc.DrawIcon(x, y, m_hIcon);
   }
   else
   {
      CDialog::OnPaint();
   }
}

// The system calls this to obtain the cursor to display while the user drags
//  the minimized window.
HCURSOR CSQPPlusCOMSampleDlg::OnQueryDragIcon()
{
   // The application icon doubles as the drag cursor (HCURSOR is a typedef of HICON).
   return static_cast<HCURSOR>(m_hIcon);
}

void CSQPPlusCOMSampleDlg::OnStart()
{
   CString strErr;
   CString strUspPath;             // Will hold the full path to the .usp file to add.

   wchar_t szTempPath[_MAX_DIR];
   GetTempPathW(_MAX_DIR, szTempPath);

   CString strOutputFile = szTempPath + CString(L"SIMCA-Q COM result.txt");

   try
   {
      _bstr_t  strProjectName;         // Will contain the observation level project name

      CLSID clsSQ = {0};
      if ((CLSIDFromProgID(L"Umetrics.SIMCAQ", &clsSQ)) != S_OK)
         return;

      auto hr = mpSQ.CreateInstance(clsSQ);
      if (hr != S_OK)
         throw("CreateInstance failed");

      // mpFile is a temp file where we store the results from our predictions
      mpFile = _wfopen(strOutputFile.GetString(), L"w, ccs=UTF-8");
      if (mpFile == nullptr)
         throw("Could not open the result file.");

      CFileDialog oDlg(true, nullptr, nullptr,
         OFN_HIDEREADONLY | OFN_OVERWRITEPROMPT | OFN_FILEMUSTEXIST | OFN_ENABLESIZING,
         L"SIMCA Project (*.usp)|*.usp|All files (*.*)|*.*||");

      oDlg.m_ofn.lpstrTitle = L"Open SIMCA Project";

      if (oDlg.DoModal() != IDOK)
         return;

      CWaitCursor oWait;

      // Now an *.usp file has been selected.
      strUspPath = oDlg.GetPathName();

      fwprintf(mpFile, L"Path to usp file:        %s\n", strUspPath.GetString());

      mpProject = mpSQ->OpenProject(_bstr_t(strUspPath), _bstr_t(""));

      // Check if the project is a batch project.
      mbIsBatchProject = mpProject->GetIsBatchProject();

      // Get the name of the project.
      strProjectName = mpProject->GetProjectName();
      fwprintf(mpFile, L"Project name:            %s\n", static_cast<const wchar_t*>(strProjectName));

      if (mbIsBatchProject)
      {
         ProcessBatchProject();
      }
      else
      {
         ProcessContinousProject();
      }
   }
   catch (const char* szError)
   {
      strErr = szError;
   }
   catch (const _com_error& err)
   {
      strErr.Append(mpSQ->GetErrorDescription(err.Error()));
   }
   catch (...)
   {
      strErr = L"Unknown error";
   }

   if (mpFile)
   {
      fclose(mpFile);
      mpFile = nullptr;
   }

   if (!strErr.IsEmpty())
   {
      AfxMessageBox(strErr);
   }
   else if (_waccess(strOutputFile, 00) == 0)
   {
      ::ShellExecuteW(nullptr, L"open", strOutputFile.GetString(), nullptr, nullptr, SW_SHOWNORMAL);
   }
}

void CSQPPlusCOMSampleDlg::ProcessContinousProject()
{
   IModelPtr               pModel;                    // Will contain the handle to a model.
   IPreparePredictionPtr   pPrepPred;                 // Will contain the data to prepare the prediction
   IPredictionPtr          pPrediction;               // Will contain the handle to the predictions.
   IVariableVectorPtr      pVarVec;
   IVariablePtr            pVariable;
   IStringVectorPtr        pstrVecNamesSettings;      // Holds the settings for a given qualitative variable.

   _bstr_t                 strModelName;              // Will contain a model name
   _bstr_t                 strVarName;                // Will contain the name of a qualitative variable
   _bstr_t                 strVarNameSettings;        // Will contain the name of a qualitative variable

   int                     iNumObs = 10;              // Number of observations to predict.

   // Get the number of models in the project.
   int lNumModels = mpProject->GetNumberOfModels();
   fwprintf(mpFile, L"Number of models: %d\n\n", lNumModels);

   for (long lModelIndex = 1; lModelIndex <= lNumModels; ++lModelIndex)
   {
      /////////////////////////////////////////////////
      // Get the model number connected with this index.
      //
      const auto nModelNumber = mpProject->GetModelNumberFromIndex(lModelIndex);

      pModel = mpProject->GetModel(nModelNumber);

      if (!pModel->IsModelFitted())
      {
         // The model is not fitted, go to the next model.
         fwprintf(mpFile, L"Model is not fitted.\n");
      }
      else
      {
         /////////////////////////////////////////////////
         // Get the name of the model connected with this index.
         //
         // Note: Model number is input NOT model index from now on.
         //
         strModelName = pModel->GetModelName();
         fwprintf(mpFile, L"==================================\n");
         fwprintf(mpFile, L"Model name: %s\n", static_cast<const wchar_t*>(strModelName));

         /////////////////////////////////////////////////
         // Create the prediction data
         pPrepPred = pModel->PreparePrediction();
         pVarVec = pPrepPred->GetVariablesForPrediction();
         int nVariables = pVarVec->GetSize();
         float fVal = 0;
         for (int iVar = 1; iVar <= nVariables; ++iVar)
         {
            pVariable = pVarVec->GetVariable(iVar);
            strVarName = pVariable->GetName(1);
            BOOL bIsLagged = pVariable->IsLagged();
            BOOL bIsQualiative = pVariable->IsQualitative();

            // Get qualitative information
            if (bIsQualiative)
            {
               pstrVecNamesSettings = pVariable->GetQualitativeSettings();
               strVarNameSettings = pstrVecNamesSettings->GetData(1); // Just take the first one
            }

            // Set lag information
            if (bIsLagged)
            {
               IIntVectorPtr pLagSteps;
               int iNumLags;
               int iMaxLag = 0;
               pLagSteps = pVariable->GetLagSteps();
               iNumLags = pLagSteps->GetSize();

               for (int iLag = 1; iLag <= iNumLags; ++iLag)
               {
                  int iLagStep;
                  iLagStep = pLagSteps->GetData(iLag);
                  iMaxLag = max(iMaxLag, iLagStep);

                  // Print the lagged variable
                  fwprintf(mpFile, L"%s.L%d\t", static_cast<const wchar_t*>(strVarName), iLagStep);
               }
               for (int iLag = 1; iLag <= iMaxLag; ++iLag)
               {
                  if (bIsQualiative)
                     pPrepPred->SetQualitativeLagData(pVariable, iLag, strVarNameSettings); // Just take the first value
                  else
                     pPrepPred->SetQuantitativeLagData(pVariable, iLag, fVal++); // Just set to a fake value
               }
            }
            else // Set non lagged prediction data
            {
               // Print the name of the variable
               if (bIsQualiative)
                  fwprintf(mpFile, L"Qualitative variable: %s\n", static_cast<const wchar_t*>(strVarName));
               else
                  fwprintf(mpFile, L"Quantitative variable: %s\n", static_cast<const wchar_t*>(strVarName));

               for (int iRow = 1; iRow <= iNumObs; ++iRow)
               {
                  if (bIsQualiative)
                     pPrepPred->SetQualitativeData(iRow, iVar, strVarNameSettings); // Just take the first value
                  else
                     pPrepPred->SetQuantitativeData(iRow, iVar, fVal++); // Just set to a fake value
               }
            }
         }

         /////////////////////////////////////////////////
         // Make the prediction
         pPrediction = pPrepPred->GetPrediction();

         fwprintf(mpFile, L"\n");
         // Get the "static" model data
         GetModelParameters(pModel);
         fwprintf(mpFile, L"----------------------------------\n");
         fwprintf(mpFile, L"Predicted Result: \n");

         // Get the predicted result
         GetResults(pModel, pPrediction);
      }
   }
}

void CSQPPlusCOMSampleDlg::ProcessBatchProject()
{
   IBatchModelPtr pBatchModel;
   _bstr_t strName;
   const int iNumObs = 10; // Number of observations to predict.
   float fVal = 0;

   int iNumBatchModels = mpProject->GetNumberOfBatchModels();

   for (long iBatchModel = 1; iBatchModel <= iNumBatchModels; ++iBatchModel)
   {
      pBatchModel = mpProject->GetBatchModel(iBatchModel);
      int iNumBEM = pBatchModel->GetNumberOfBEM();
      int iNumBLM = pBatchModel->GetNumberOfBLM();
      if (iNumBLM == 0) // Only predict those who have a batch level model.
         continue;

      // Print Batch level model info
      for (int iBLMIx = 1; iBLMIx <= iNumBLM; ++iBLMIx)
      {
         IBatchLevelModelPtr pBLModel = pBatchModel->GetBatchLevelModel(pBatchModel->GetBatchLevelModelNumber(iBLMIx));

         if (!pBLModel->IsModelFitted())
         {
            // The model is not fitted, go to the next model.
            fwprintf(mpFile, L"Model is not fitted.\n");
         }
         else
         {
            // Get the name of the model connected with this index.
            strName = pBLModel->GetModelName();
            fwprintf(mpFile, L"==================================\n");
            fwprintf(mpFile, L"Model name: %s\n", static_cast<const wchar_t*>(strName));

            // Create a fake prediction set
            IPrepareBatchPredictionPtr pPrepareBatchPrediction = pBLModel->GetPrepareBatchPrediction();
            for (int iPhase = 1; iPhase <= iNumBEM; ++iPhase)
            {
               const int iModelNumber = pBatchModel->GetBatchEvolutionModelNumber(iPhase);
               IVariableVectorPtr pVariableVector = pPrepareBatchPrediction->GetVariablesForBatchPrediction(iPhase);
               const int iNumVariables = pVariableVector->GetSize();
               fwprintf(mpFile, L"Variables for model %d:\n", iModelNumber);

               for (int iVariable = 1; iVariable <= iNumVariables; ++iVariable)
               {
                  IStringVectorPtr pQualSettings;
                  int iNumQualSettings = 0;
                  IVariablePtr pVariable = pVariableVector->GetVariable(iVariable);
                  BOOL bIsQual = pVariable->IsQualitative();
                  strName = pVariable->GetName(1);

                  if (bIsQual == TRUE)
                  {
                     pQualSettings = pVariable->GetQualitativeSettings();
                     iNumQualSettings = pQualSettings->GetSize();
                  }

                  fwprintf(mpFile, L"%s\t", static_cast<const wchar_t*>(strName));
                  for (int iRow = 1; iRow <= iNumObs; ++iRow)
                  {
                     if (bIsQual == TRUE)
                     {
                        int iSetting = (rand() % (iNumQualSettings - 1) + 1); // Rand a number from 1 to iNumQualSettings.

                        strName = pQualSettings->GetData(iSetting);
                        pPrepareBatchPrediction->SetQualitativeBatchData(iPhase, iRow, iVariable, strName);
                     }
                     else
                     {
                        pPrepareBatchPrediction->SetQuantitativeBatchData(iPhase, iRow, iVariable, ++fVal);
                     }
                  }
               }
            }

            // Make the prediction
            IBatchPredictionPtr pPrediction = pPrepareBatchPrediction->GetBatchPrediction();

            // Print the result
            for (int iPhase = 1; iPhase <= iNumBEM; ++iPhase)
            {
               const int iModelNumber = pBatchModel->GetBatchEvolutionModelNumber(iPhase);
               IBatchEvolutionModelPtr pBEModel = pBatchModel->GetBatchEvolutionModel(iModelNumber);
               IBatchEvolutionPredictionPtr pBEPrediction = pPrediction->GetBatchEvolutionPrediction(iModelNumber);
               fwprintf(mpFile, L"\n");
               // Get the "static" model data
               GetModelParameters(pBEModel);
               GetAlignedParameters(pBEModel);
               fwprintf(mpFile, L"----------------------------------\n");
               fwprintf(mpFile, L"Predicted Result: \n");

               // Get the predicted result
               GetResults(pBEModel, pBEPrediction);
               GetAlignedResults(pBEModel, pBEPrediction);
            }
         }
      }
   }
}

// Writes the "static" (non-prediction) results of pModel to mpFile, in this order:
//  - score single weight contributions, P, T and T2Range (only for models with at least
//    one component);
//  - DModX;
//  - C for predictive models, or Q2VX for other models whose type is defined.
void CSQPPlusCOMSampleDlg::GetModelParameters(IModelPtr pModel)
{
   IVectorDataPtr pData;             // Result data

   fwprintf(mpFile, L"----------------------------------\n");
   fwprintf(mpFile, L"Model data for model number %d\n", pModel->GetModelNumber());

   const long lNumComp = pModel->GetNumberOfComponents();
   fwprintf(mpFile, L"Number of components is %ld.\n", lNumComp);

   // Get the model type for this model
   const ModelType eModelType = pModel->GetModelType();

   // Contributions, P, T and T2Range are component-based and not valid for a zero-component
   // model. GetResults applies the same guard to their prediction-set counterparts. Without
   // it, the resulting COM error would abort the whole run.
   if (lNumComp > 0)
   {
      // Get the score single weight contributions from the model.
      pData = pModel->GetContributionsScoresSingleWeight(0 /*iObs1Ix*/, 1/*iObs2Ix*/, eWeight_Normalized, lNumComp, 1, eReconstruct_False /*bReconstruct*/);
      fwprintf(mpFile, L"Contribution SSW:\n");
      PrintVectorData(pData);

      pData = pModel->GetP(NULL /*pComponents*/, eReconstruct_False /*bReconstruct*/);
      fwprintf(mpFile, L"P:\n");
      PrintVectorData(pData);

      pData = pModel->GetT(NULL /*pComponents*/);
      fwprintf(mpFile, L"T:\n");
      PrintVectorData(pData);

      /* Get T2Range */
      pData = pModel->GetT2Range(1, -1); // -1 equals last component.
      fwprintf(mpFile, L"T2Range:\n");
      PrintVectorData(pData);
   }

   /* Get DModX for all components in the model. */
   pData = pModel->GetDModX(NULL /*pnComponents*/, eNormalized_True /*bNormalized*/, eModelingPowerWeighted_False /*bModelingPowerWeighted*/);
   fwprintf(mpFile, L"DModX:\n");
   PrintVectorData(pData);

   if (IsPredictiveModel(eModelType))
   {
      // C is only valid for a PLS model of some kind.

      // Get C for all components in the model.
      pData = pModel->GetC(NULL /*pComponents*/);
      fwprintf(mpFile, L"C:\n");
      PrintVectorData(pData);
   }
   else
   {
      /* Q2VX is only valid for a PCA model. */

      /* Get Q2VX */
      if (eModelType != eUnDefined) // Bug 14433: SIMCA-Q 14.1 returns eUndefined for O2PLS-DA models.
      {
         pData = pModel->GetQ2VX(NULL);
         fwprintf(mpFile, L"Q2VX:\n");
         PrintVectorData(pData);
      }
   }
}

// Writes the aligned T2Range of a batch evolution model to mpFile. It is only requested
// for predictive model types (see IsPredictiveModel); other types write nothing.
void CSQPPlusCOMSampleDlg::GetAlignedParameters(IBatchEvolutionModelPtr pBEModel)
{
   const ModelType eType = pBEModel->GetModelType();
   if (!IsPredictiveModel(eType))
      return;

   IVectorDataPtr pAlignedT2 = pBEModel->GetAlignedT2Range(1 /*iCompFrom*/, -1 /*iCompTo*/, NULL);
   fwprintf(mpFile, L"AlignedT2Range:\n");
   PrintVectorData(pAlignedT2);
}

// Writes the prediction-set results of pPrediction (made with pModel) to mpFile, in this
// order:
//  - XVarPS and DModXPS, always;
//  - predicted score contribution, TPS and T2RangePS, only when the model has components;
//  - YPredPS, only for predictive models with components.
void CSQPPlusCOMSampleDlg::GetResults(IModelPtr pModel, IPredictionPtr pPrediction)
{
   // Empty index vector passed as the column/component list.
   // NOTE(review): presumably "empty = all"; confirm against the SIMCA-Q documentation.
   IIntVectorPtr pnNoIndexes = mpSQ->GetNewIntVector(0);

   const long lComponents = pModel->GetNumberOfComponents();
   const ModelType eType = pModel->GetModelType();
   const bool bHasComponents = lComponents > 0;

   // Writes a caption line followed by the data of one result.
   auto WriteResult = [this](const wchar_t* szCaption, IVectorDataPtr pResult)
   {
      fwprintf(mpFile, L"%s", szCaption);
      PrintVectorData(pResult);
   };

   // Predicted X variables, unscaled and back-transformed.
   WriteResult(L"XVarPS:\n", pPrediction->GetXVarPS(eUnscaled_True, eBacktransformed_True, pnNoIndexes));

   // Single weight score contribution between the average (observation index 0) and the
   // first observation (index 1), for the last component and Y variable 1. This request
   // needs at least one component.
   if (bHasComponents)
   {
      WriteResult(L"Predicted Score Contribution:\n",
         pPrediction->GetContributionsScorePSSingleWeight(0, 1, eWeight_Normalized, lComponents, 1, eReconstruct_False));
   }

   WriteResult(L"DModXPS:\n",
      pPrediction->GetDModXPS(pnNoIndexes /*pComponentList*/, eNormalized_True, eModelingPowerWeighted_False));

   // Predicted scores and their T2 range are not valid for a zero-component model.
   if (bHasComponents)
   {
      WriteResult(L"TPS:\n", pPrediction->GetTPS(pnNoIndexes /*pComponentList*/));
      WriteResult(L"T2RangePS:\n", pPrediction->GetT2RangePS(1, -1)); // -1 equals last component.
   }

   // Predicted Y only exists for models with a Y block.
   if (bHasComponents && IsPredictiveModel(eType))
   {
      WriteResult(L"YPredPS:\n",
         pPrediction->GetYPredPS(lComponents, eUnscaled_False, eBacktransformed_False, NULL /*pnColumnYIndexes*/));
   }
}

// Writes the aligned prediction-set results of one batch phase to mpFile: aligned
// T2RangePS and aligned time/maturity. Only predictive models with at least one component
// produce output; for other models nothing is written.
void CSQPPlusCOMSampleDlg::GetAlignedResults(IModelPtr pModel, IBatchEvolutionPredictionPtr pPhasePrediction)
{
   // Number of components in the model (not the project).
   const long lNumComp = pModel->GetNumberOfComponents();

   // Get the model type for this model
   const ModelType eModelType = pModel->GetModelType();

   if (lNumComp <= 0 || !IsPredictiveModel(eModelType))
      return;

   IVectorDataPtr pData;   // Result data, reused for each request.

   pData = pPhasePrediction->GetAlignedT2RangePS(1 /*iCompFrom*/, -1 /*iCompTo*/);
   fwprintf(mpFile, L"GetAlignedT2RangePS:\n");
   PrintVectorData(pData);

   pData = pPhasePrediction->GetAlignedTimeMaturityPS();
   fwprintf(mpFile, L"GetAlignedTimeMaturityPS:\n");
   PrintVectorData(pData);
}

// Writes the data matrix of pData to mpFile. Each matrix column goes on its own line, with
// its values (one per row) tab-separated and formatted with "%f". A blank line ends the
// output.
void CSQPPlusCOMSampleDlg::PrintVectorData(IVectorDataPtr& pData)
{
   IFloatMatrixPtr pfMat = pData->GetDataMatrix();

   // Query the dimensions once instead of making two COM calls on every loop iteration.
   const int iNumCols = pfMat->GetNumberOfCols();
   const int iNumRows = pfMat->GetNumberOfRows();

   for (int iColIter = 1; iColIter <= iNumCols; ++iColIter)
   {
      for (int iRowIter = 1; iRowIter <= iNumRows; ++iRowIter)
      {
         const float fVal = pfMat->GetData(iRowIter, iColIter);
         fwprintf(mpFile, L"%f\t", fVal);
      }
      fwprintf(mpFile, L"\n");
   }
   fwprintf(mpFile, L"\n");
}